{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install TensorTrade"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel; `!python3 -m pip`\n",
    "# installs into whichever python3 is first on PATH, which may be a different one.\n",
    "# NOTE(review): consider pinning a tag/commit for reproducibility -- this notebook\n",
    "# imports tensortrade.agents.DQNAgent, which newer releases may not ship (TODO confirm).\n",
    "%pip install git+https://github.com/tensortrade-org/tensortrade.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up Data Fetching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "import tensortrade.env.default as default\n",
    "\n",
    "from tensortrade.data.cdd import CryptoDataDownload\n",
    "from tensortrade.feed.core import Stream, DataFeed\n",
    "from tensortrade.oms.exchanges import Exchange\n",
    "from tensortrade.oms.services.execution.simulated import execute_order\n",
    "from tensortrade.oms.instruments import USD, BTC, ETH\n",
    "from tensortrade.oms.wallets import Wallet, Portfolio\n",
    "from tensortrade.agents import DQNAgent\n",
    "\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hourly (\"1h\") Bitstamp BTC candles priced in USD, from CryptoDataDownload.\n",
    "cdd = CryptoDataDownload()\n",
    "data = cdd.fetch(\n",
    "    \"Bitstamp\", \"USD\", \"BTC\", \"1h\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>date</th>\n",
       "      <th>open</th>\n",
       "      <th>high</th>\n",
       "      <th>low</th>\n",
       "      <th>close</th>\n",
       "      <th>volume</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2017-07-01 11:00:00</td>\n",
       "      <td>2505.56</td>\n",
       "      <td>2513.38</td>\n",
       "      <td>2495.12</td>\n",
       "      <td>2509.17</td>\n",
       "      <td>287000.32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2017-07-01 12:00:00</td>\n",
       "      <td>2509.17</td>\n",
       "      <td>2512.87</td>\n",
       "      <td>2484.99</td>\n",
       "      <td>2488.43</td>\n",
       "      <td>393142.50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2017-07-01 13:00:00</td>\n",
       "      <td>2488.43</td>\n",
       "      <td>2488.43</td>\n",
       "      <td>2454.40</td>\n",
       "      <td>2454.43</td>\n",
       "      <td>693254.01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2017-07-01 14:00:00</td>\n",
       "      <td>2454.43</td>\n",
       "      <td>2473.93</td>\n",
       "      <td>2450.83</td>\n",
       "      <td>2459.35</td>\n",
       "      <td>712864.80</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2017-07-01 15:00:00</td>\n",
       "      <td>2459.35</td>\n",
       "      <td>2475.00</td>\n",
       "      <td>2450.00</td>\n",
       "      <td>2467.83</td>\n",
       "      <td>682105.41</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 date     open     high      low    close     volume\n",
       "0 2017-07-01 11:00:00  2505.56  2513.38  2495.12  2509.17  287000.32\n",
       "1 2017-07-01 12:00:00  2509.17  2512.87  2484.99  2488.43  393142.50\n",
       "2 2017-07-01 13:00:00  2488.43  2488.43  2454.40  2454.43  693254.01\n",
       "3 2017-07-01 14:00:00  2454.43  2473.93  2450.83  2459.35  712864.80\n",
       "4 2017-07-01 15:00:00  2459.35  2475.00  2450.00  2467.83  682105.41"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity-check the downloaded frame: a date column followed by OHLCV columns.\n",
    "data.head(n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Features with the Feed Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rsi(price: Stream[float], period: float) -> Stream[float]:\n",
    "    \"\"\"Relative Strength Index of `price`, in the range [0, 100].\n",
    "\n",
    "    Per-step gains and losses are exponentially weighted with\n",
    "    alpha = 1 / period and combined as 100 * (1 - 1 / (1 + gain / loss)).\n",
    "    \"\"\"\n",
    "    delta = price.diff()\n",
    "    alpha = 1 / period\n",
    "    gains = delta.clamp_min(0).abs()\n",
    "    losses = delta.clamp_max(0).abs()  # magnitude of the negative moves\n",
    "    avg_gain = gains.ewm(alpha=alpha).mean()\n",
    "    avg_loss = losses.ewm(alpha=alpha).mean()\n",
    "    return 100 * (1 - (1 + avg_gain / avg_loss) ** -1)\n",
    "\n",
    "\n",
    "def macd(price: Stream[float], fast: float, slow: float, signal: float) -> Stream[float]:\n",
    "    \"\"\"MACD histogram of `price`: the MACD line minus its signal line.\n",
    "\n",
    "    The MACD line is EWM(span=fast) - EWM(span=slow) of `price`, and the\n",
    "    signal line is EWM(span=signal) of the MACD line (all with adjust=False).\n",
    "    \"\"\"\n",
    "    fast_ma = price.ewm(span=fast, adjust=False).mean()\n",
    "    slow_ma = price.ewm(span=slow, adjust=False).mean()\n",
    "    macd_line = fast_ma - slow_ma\n",
    "    # Separate name so the `signal` span parameter is not rebound.\n",
    "    signal_line = macd_line.ewm(span=signal, adjust=False).mean()\n",
    "    return macd_line - signal_line\n",
    "\n",
    "\n",
    "# One float source stream per price/volume column, selected by name rather than\n",
    "# by position so this does not depend on \"date\" being the first column.\n",
    "ohlcv = [\n",
    "    Stream.source(list(data[col]), dtype=\"float\").rename(col)\n",
    "    for col in data.columns\n",
    "    if col != \"date\"\n",
    "]\n",
    "\n",
    "# Close-price stream; every feature below is derived from it.\n",
    "cp = Stream.select(ohlcv, lambda s: s.name == \"close\")\n",
    "\n",
    "# Observation features: log return, 20-period RSI and the 10/50/5 MACD histogram.\n",
    "features = [\n",
    "    cp.log().diff().rename(\"lr\"),\n",
    "    rsi(cp, period=20).rename(\"rsi\"),\n",
    "    macd(cp, fast=10, slow=50, signal=5).rename(\"macd\"),\n",
    "]\n",
    "\n",
    "feed = DataFeed(features)\n",
    "feed.compile()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'lr': nan, 'rsi': nan, 'macd': 0.0}\n",
      "{'lr': -0.008300031641449657, 'rsi': 0.0, 'macd': -1.9717171717171975}\n",
      "{'lr': -0.01375743446296962, 'rsi': 0.0, 'macd': -6.082702245269603}\n",
      "{'lr': 0.0020025323250756344, 'rsi': 8.795475693113076, 'macd': -7.287625162566419}\n",
      "{'lr': 0.00344213459739251, 'rsi': 21.34663357024277, 'macd': -6.522181201739986}\n"
     ]
    }
   ],
   "source": [
    "# Preview the first five feature rows (this advances `feed`).\n",
    "for _ in range(5):\n",
    "    row = feed.next()\n",
    "    print(row)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up the Trading Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulated Bitstamp exchange; the \"USD-BTC\" stream supplies the BTC close price in USD.\n",
    "bitstamp = Exchange(\"bitstamp\", service=execute_order)(\n",
    "    Stream.source(list(data[\"close\"]), dtype=\"float\").rename(\"USD-BTC\")\n",
    ")\n",
    "\n",
    "# Starting balances: 10,000 USD and 10 BTC; net worth is measured in USD.\n",
    "portfolio = Portfolio(USD, [\n",
    "    Wallet(bitstamp, 10000 * USD),\n",
    "    Wallet(bitstamp, 10 * BTC)\n",
    "])\n",
    "\n",
    "# Raw timestamps plus OHLCV columns, passed to the chart renderer.\n",
    "renderer_feed = DataFeed(\n",
    "    [Stream.source(list(data[\"date\"])).rename(\"date\")]\n",
    "    + [\n",
    "        Stream.source(list(data[col]), dtype=\"float\").rename(col)\n",
    "        for col in [\"open\", \"high\", \"low\", \"close\", \"volume\"]\n",
    "    ]\n",
    ")\n",
    "\n",
    "env = default.create(\n",
    "    portfolio=portfolio,\n",
    "    action_scheme=\"managed-risk\",\n",
    "    reward_scheme=\"risk-adjusted\",\n",
    "    feed=feed,\n",
    "    renderer_feed=renderer_feed,\n",
    "    renderer=default.renderers.PlotlyTradingChart(),\n",
    "    window_size=20  # observation window length, in rows of `feed`\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'internal': {'bitstamp:/USD-BTC': 2509.17,\n",
       "  'bitstamp:/USD:/free': 10000.0,\n",
       "  'bitstamp:/USD:/locked': 0.0,\n",
       "  'bitstamp:/USD:/total': 10000.0,\n",
       "  'bitstamp:/BTC:/free': 10.0,\n",
       "  'bitstamp:/BTC:/locked': 0.0,\n",
       "  'bitstamp:/BTC:/total': 10.0,\n",
       "  'bitstamp:/BTC:/worth': 25091.7,\n",
       "  'net_worth': 35091.7},\n",
       " 'external': {'lr': nan, 'rsi': nan, 'macd': 0.0},\n",
       " 'renderer': {'date': Timestamp('2017-07-01 11:00:00'),\n",
       "  'open': 2505.56,\n",
       "  'high': 2513.38,\n",
       "  'low': 2495.12,\n",
       "  'close': 2509.17,\n",
       "  'volume': 287000.32}}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One step of the environment's combined feed: internal portfolio state,\n",
    "# external features and renderer data. Note: this advances the observer's feed.\n",
    "observer_feed = env.observer.feed\n",
    "observer_feed.next()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up and Train a DQN Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00e249aae8ab415999b852a8f0748f79",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FigureWidget({\n",
       "    'data': [{'close': array([2509.17, 2488.43, 2454.43, ..., 2553.79, 2539.82, 2542.72]),\n",
       "    …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "49999.36214535988"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training writes agent files under save_path. Create the directory up front so a\n",
    "# fresh checkout does not rely on `agents/` already existing (whether the model-save\n",
    "# call creates missing directories itself depends on the installed Keras version).\n",
    "os.makedirs(\"agents/\", exist_ok=True)\n",
    "\n",
    "agent = DQNAgent(env)\n",
    "\n",
    "# Demo-scale run (2 episodes, n_steps=200); increase both for a meaningful agent.\n",
    "agent.train(n_steps=200, n_episodes=2, save_path=\"agents/\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
